In [ ]:
# Install whoosh and whoosh_utils library
!pip install /kaggle/input/uspto-whoosh-reloaded-2-7-5-patched/Whoosh_Reloaded-2.7.5-py2.py3-none-any.whl
!sed 's:/kaggle/input/whoosh-wheel-2-7-4/Whoosh-2.7.4-py2.py3-none-any.whl:whoosh-reloaded==2.7.5:g' /kaggle/usr/lib/whoosh_utils/whoosh_utils.py > whoosh_utils.py
In [ ]:
import gc
import os
import re

# The Keras backend must be selected before keras / keras_nlp are imported
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"

import keras
import keras_nlp
import numpy as np
import pandas as pd
import whoosh
import whoosh.analysis  # explicit submodule import; `import whoosh` alone is not guaranteed to load it
from tqdm import tqdm

import whoosh_utils
In [ ]:
# Record the library versions this run used
print(f"Keras: {keras.__version__}")
print(f"KerasNLP: {keras_nlp.__version__}")
In [ ]:
class CFG:
    """Settings for this keyword-generation / query-building inference run."""

    # Reproducibility
    seed: int = 42

    # Data and model
    dataset_path: str = "/kaggle/input/uspto-explainable-ai"
    preset: str = "gemma_1.1_instruct_2b_en"  # KerasNLP preset name of the pretrained Gemma model

    # Generation
    input_length: int = 1024  # sequence length assigned to the model's preprocessor
    output_length: int = 1200  # max_length passed to gemma_lm.generate
    num_neighbors: int = 2  # how many neighbour patents each test patent is compared with
In [ ]:
# Seed Python, NumPy and the Keras backend RNGs from the config
keras.utils.set_random_seed(seed=CFG.seed)
In [ ]:
# Read the test patents, keeping the id column plus the first CFG.num_neighbors neighbour columns
test_df = pd.read_csv(f"{CFG.dataset_path}/test.csv")
test_df = test_df.iloc[:, :CFG.num_neighbors + 1]
target_cols = list(test_df.columns[1:])
n_test_rows = len(test_df)

# Merge metadata of the patents
meta_df = pd.read_parquet(f"{CFG.dataset_path}/patent_metadata.parquet")
test_df = test_df.merge(meta_df, on="publication_number", how="left")

# Title/abstract corpus, restricted to the ids this notebook needs (test patents and
# their neighbours) so that the repeated merges below do not scan the whole corpus
patent_df = pd.read_parquet("/kaggle/input/uspto-all-patents-after-1975/all_patents.parquet")
needed_ids = pd.unique(test_df[["publication_number", *target_cols]].to_numpy().ravel())
patent_df = patent_df[patent_df["publication_number"].isin(needed_ids)]

# Merge title and abstract of the patents. how="left" keeps test patents that are
# missing from the corpus (their text is filled with "" below); the default inner
# join would silently drop them from the submission.
test_df = test_df.merge(patent_df, on="publication_number", how="left")

# Fill NaN values
test_df["title"] = test_df["title"].fillna("")
test_df["abstract"] = test_df["abstract"].fillna("")

# Merge title and abstract of the neighbour patents as title_{i} / abstract_{i}
for i in range(CFG.num_neighbors):
    test_df = test_df.merge(
        patent_df,
        left_on=target_cols[i],
        right_on="publication_number",
        how="left",
        suffixes=("", f"_{i}"),
    )

    # Fill NaN values (neighbour not found in the corpus)
    test_df[f"title_{i}"] = test_df[f"title_{i}"].fillna("")
    test_df[f"abstract_{i}"] = test_df[f"abstract_{i}"].fillna("")

    # Drop extra publication_number column from merges
    test_df = test_df.drop(columns=[f"publication_number_{i}"])

# Every test patent must survive the merges exactly once (duplicate ids in the
# metadata or corpus would multiply rows and corrupt the submission)
assert len(test_df) == n_test_rows, "merges changed the number of test rows"

# Reset index order as it will be used later for iteration
test_df = test_df.reset_index(drop=True)

# Clean up memory
del meta_df, patent_df, needed_ids
gc.collect()
In [ ]:
# Peek at the merged frame
test_df.head(n=5)
In [ ]:
# Instruction prompt; {title_a}/{abstract_a} are the test patent, {title_b}/{abstract_b} a neighbour
prompt_template = (
    "Task:\n"
    "Analyze and compare the given two patent abstracts and titles, and identify the common or similar "
    "query keywords that should yield these two patents when searched in the United States Patent and "
    "Trademark Office (USPTO) database.\n"
    "\n"
    "Instructions:\n"
    "1. Carefully read and understand the provided 'Patent 1' and 'Patent 2' titles and abstracts below.\n"
    "2. Identify the key terms, concepts, and components that are either common or similar in both "
    "patent titles and abstracts.\n"
    "3. In the 'Keywords' section below, write the common or similar keywords, separating each keyword "
    "with a semicolon (';') and a space (' '). Here is an example response, "
    "'keyword1; keyword2; keyword3_1 keyword3_2'.\n"
    "4. Do not add any additional narratives or text before or after the keywords.\n"
    "\n"
    "Patent 1:\n"
    "* Title: {title_a}\n"
    "* Abstract: {abstract_a}\n"
    "\n"
    "Patent 2:\n"
    "* Title: {title_b}\n"
    "* Abstract: {abstract_b}\n"
    "\n"
    "Keywords:"
)

# Wrap the prompt in Gemma's user/model turn markers; placeholders stay unformatted
chat_template = "<start_of_turn>user\n" + prompt_template + "<end_of_turn>\n<start_of_turn>model\n"
In [ ]:
# Show the raw prompt with its unfilled placeholders
print(prompt_template, end="\n")
In [ ]:
def create_prompt(row, neighbor_idx):
    """Render ``chat_template`` for one (test patent, neighbour) pair.

    Args:
        row: A row of ``test_df`` holding ``title``, ``abstract`` and the
            ``title_{neighbor_idx}`` / ``abstract_{neighbor_idx}`` columns.
        neighbor_idx: Which neighbour's text to use as "Patent 2".

    Returns:
        The formatted chat prompt string.
    """
    placeholders = {
        "title_a": row["title"],
        "abstract_a": row["abstract"],
        "title_b": row[f"title_{neighbor_idx}"],
        "abstract_b": row[f"abstract_{neighbor_idx}"],
    }
    return chat_template.format(**placeholders)
In [ ]:
# Preview the prompt for one example pair: row 2 when it exists, otherwise the
# last row (a hard-coded iloc[2] fails on test sets with fewer than 3 patents)
sample_row = test_df.iloc[min(2, len(test_df) - 1)]
prompt_sample = create_prompt(sample_row, 0)
print(prompt_sample)
In [ ]:
# Load the Gemma checkpoint named in the config (previously hard-coded here,
# so editing CFG.preset had no effect)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(CFG.preset)

# Cap the preprocessor's sequence length to keep memory use and latency low
gemma_lm.preprocessor.sequence_length = CFG.input_length
In [ ]:
def generate_keyword(row, neighbor_idx):
    """Ask the model for keywords shared by a test patent and one neighbour.

    Args:
        row: A row of ``test_df`` with ``title``, ``abstract``,
            ``title_{neighbor_idx}`` and ``abstract_{neighbor_idx}``.
        neighbor_idx: Index of the neighbour to compare against.

    Returns:
        A list of keyword strings, or ``[""]`` when there is no text at all or
        when generation/decoding fails.
    """
    fields = [
        "title",
        "abstract",
        f"title_{neighbor_idx}",
        f"abstract_{neighbor_idx}",
    ]
    # Skip the model call only when every field is empty (neither patent has text)
    if all(row[field] == "" for field in fields):
        return [""]

    # Create prompt
    prompt = create_prompt(row, neighbor_idx)

    try:
        # Generate output from model
        output = gemma_lm.generate(prompt, max_length=CFG.output_length)

        # Extract keywords from model output
        keyword = decode_output(output, prompt)
    except Exception as exc:
        # Best effort: one failed pair must not abort the whole run. Report it
        # instead of failing silently. Unlike a bare `except:`, this still
        # lets KeyboardInterrupt stop the loop.
        print(f"Keyword generation failed for neighbour {neighbor_idx}: {exc!r}")
        keyword = [""]

    return keyword


def decode_output(output, prompt, max_keyword_length=40):
    """Turn raw model output into a de-duplicated list of keywords.

    Args:
        output: Text returned by ``gemma_lm.generate``; the prompt is removed from it.
        prompt: The prompt that produced ``output``.
        max_keyword_length: Pieces with this many characters or more are
            discarded (40, as before).

    Returns:
        Keywords in the order the model produced them, or ``[""]`` when none
        survive filtering.
    """
    # Remove input prompt from model output
    answer = output.replace(prompt, "").strip()

    # If the answer still contains both "Title:" and "Abstract:", the prompt
    # text was not removed (the original note attributes this to the output
    # exceeding max_length), so the answer is unusable.
    if "Title:" in answer and "Abstract:" in answer:
        return [""]

    # Filter out possible unwanted output text
    for noise in ["Keywords:", "solution:", "Solution", "**", "\n\n", "[", "]"]:
        answer = answer.replace(noise, "").strip()

    # Split on the first delimiter (in priority order) present in the answer.
    # An answer without any delimiter is one keyword and must stay one string;
    # iterating the bare string would otherwise yield single characters.
    pieces = [answer]
    for sep in [";\n-", ",\n-", "\n-", ";\n", ";\n*", ",\n*", ",\n", "\n*", ",", ";"]:
        if sep in answer:
            pieces = answer.strip(sep).split(sep)
            break

    # Final filtering: drop overly long pieces, remove '*' and '.', skip empties
    keywords = []
    for piece in pieces:
        piece = piece.strip()
        if len(piece) >= max_keyword_length:
            continue
        keyword = piece.replace("*", "").replace(".", "").strip()
        if keyword:
            keywords.append(keyword)

    # dict.fromkeys de-duplicates while keeping order. A set's string order can
    # change between Python processes, and the order matters for query trimming.
    keywords = list(dict.fromkeys(keywords))

    # If no keywords were found, return a single empty-string keyword
    return keywords or [""]
In [ ]:
# Generate keywords for every test patent: one model call per (patent, neighbour)
# pair, pooled into one de-duplicated keyword list per patent.
keywords_all = []

for _row_idx, row in tqdm(test_df.iterrows(), total=test_df.shape[0]):
    keywords = []

    # Distinct loop variable; the original reused the outer loop's `i` here
    for neighbor_idx in range(CFG.num_neighbors):
        keywords += generate_keyword(row, neighbor_idx=neighbor_idx)

    # dict.fromkeys drops duplicates while keeping first-seen order; set() gives
    # an order that can differ between Python processes
    keywords_all.append(list(dict.fromkeys(keywords)))
In [ ]:
# Show the keywords generated for the first three test patents
for idx, patent_keywords in enumerate(keywords_all[:3]):
    print(f"Keywords {idx}: {patent_keywords}", end="\n\n")
In [ ]:
# Stopwords removed from keyword text (handed to StandardAnalyzer below)
BRS_STOPWORDS = [
    'an', 'are', 'by', 'for', 'if', 'into', 'is', 'no', 'not', 'of', 'on', 'such',
    'that', 'the', 'their', 'then', 'there', 'these', 'they', 'this', 'to', 'was',
    'will', 'and', 'or',
]

# A plain or comma-grouped integer with an optional decimal part, e.g. "7", "1,000", "3.14"
NUMBER_REGEX = re.compile(r'^(\d+|\d{1,3}(,\d{3})*)(\.\d+)?$')


class NumberFilter(whoosh.analysis.Filter):
    """Whoosh filter that drops tokens whose text is just a number."""

    def __call__(self, tokens):
        for token in tokens:
            if NUMBER_REGEX.match(token.text):
                continue
            yield token


# Whoosh StandardAnalyzer with the stopwords above, followed by number removal
custom_analyzer = whoosh.analysis.StandardAnalyzer(stoplist=BRS_STOPWORDS) | NumberFilter()
In [ ]:
# Sanity check: "that", "there" (stopwords) and "1.023" (a number) are expected to be removed
sample_tokens = custom_analyzer("device, 1.023, machine, that, learning, there")
[tok.text for tok in sample_tokens]
In [ ]:
# Single validator instance, reused by validate_query() for every query
query_validator = whoosh_utils.QueryValidator()

def validate_query(query, default_query="ti:device"):
    """Return ``query`` if ``query_validator`` accepts it, else a fallback query.

    Args:
        query: Candidate search query. Non-strings (e.g. None/NaN) and empty
            strings are replaced by ``default_query``.
        default_query: Query returned when ``query`` is missing or rejected.

    Returns:
        A query string: ``query`` itself when valid, otherwise ``default_query``.
    """
    # Check the type first: the original evaluated len(query) before the
    # isinstance test, so None/NaN raised TypeError instead of falling back.
    if not isinstance(query, str) or not query:
        query = default_query
    try:
        query_validator.validate_query(query)
    except Exception:
        # Any validation failure means the query is unusable; fall back quietly
        query = default_query
    return query
In [ ]:
validate_query(query="device OR machine")  # valid: returned unchanged
In [ ]:
validate_query(query="(device OR machine")  # invalid (unbalanced parenthesis): default query returned
In [ ]:
# Build one search query per test patent: its CPC codes AND-ed with the model
# keywords searched in the `detd:` field, trimmed until whoosh_utils counts
# fewer than MAX_QUERY_TOKENS tokens.
MAX_CPC_CODES = 15
MAX_QUERY_TOKENS = 50

queries = []

for row_idx, row in tqdm(test_df.iterrows(), total=len(test_df)):
    # Create query from cpc_codes. A patent missing from the left metadata merge
    # has NaN here instead of a sequence; treat that as "no codes" instead of crashing.
    cpc = row["cpc_codes"]
    cpc = list(cpc) if isinstance(cpc, (list, tuple, np.ndarray)) else []
    query_cpc = f"cpc:({' OR '.join(cpc[:MAX_CPC_CODES])})" if cpc else ""

    # Per-row fallback, reset every iteration. Previously a row with no usable
    # keyword query silently reused the previous row's query (or hit a
    # NameError on the first row).
    query = query_cpc

    try:
        # Analyze the keywords (test_df has a RangeIndex, so row_idx is positional)
        keywords_str = ", ".join(keywords_all[row_idx])
        # dict.fromkeys de-duplicates while keeping order, so the tokens dropped
        # by pop() below are the same on every run
        tokens = list(dict.fromkeys(token.text for token in custom_analyzer(keywords_str)))

        # Drop keywords from the end until the query is small enough
        while tokens:
            query_keywords = f"({' OR '.join(tokens)})"

            # Merge queries from keywords and cpc_codes
            query_check = f"detd:{query_keywords}" + (f" AND {query_cpc}" if query_cpc else "")

            if whoosh_utils.count_query_tokens(query_check) < MAX_QUERY_TOKENS:
                query = query_check
                break

            tokens.pop()
    except Exception as exc:
        tqdm.write(f"Query building failed for row {row_idx}: {exc!r}")
        query = query_cpc

    # Final query validation
    queries.append(validate_query(query))
In [ ]:
# Show the queries built for the first three test patents
for idx, query_text in enumerate(queries[:3]):
    print(f"Query {idx}: {query_text}", end="\n\n")
In [ ]:
# Attach the final queries and write the two-column submission file
test_df["query"] = queries
pred_df = test_df.loc[:, ["publication_number", "query"]]
pred_df.to_csv("submission.csv", index=False)
pred_df.head()